{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "cfbbf17a-5bda-4d39-a153-437feaee98ec",
   "metadata": {},
   "source": [
    "# 简单的 PyTorch 示例"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f483a0f4-88c8-45f7-a47e-e2408076011c",
   "metadata": {},
   "source": [
    "## 1. 导入必要的库"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "78e87915-58f5-45a7-9eba-10ac9c329f4f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "891c1456-8558-4bc1-afc2-72175cd7829e",
   "metadata": {},
   "source": [
    "## 2. 创建数据集"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "36384a25-0993-40c7-8fe7-3e333a59c4d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fix the global RNG seed so the randomly generated tensors below are reproducible.\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Dimensions of the toy regression problem: features per sample, targets per sample,\n",
    "# and the number of samples.\n",
    "input_size, output_size, batch_size = 10, 1, 5\n",
    "\n",
    "# Draw the inputs first and the targets second (the order matters for reproducibility).\n",
    "x = torch.randn((batch_size, input_size))\n",
    "y = torch.randn((batch_size, output_size))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afc125fb-9533-4a3a-9c99-113a91520403",
   "metadata": {},
   "source": [
    "## 3. 定义模型\n",
    "\n",
    "定义一个简单的线性模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "d97d4bb0-4539-4679-8b8c-dceb66703e17",
   "metadata": {},
   "outputs": [],
   "source": [
    "class SimpleModel(nn.Module):\n",
    "    \"\"\"A single fully connected layer mapping `input_dim` features to `output_dim` outputs.\"\"\"\n",
    "\n",
    "    def __init__(self, input_dim, output_dim):\n",
    "        \"\"\"Create the linear layer with the given input and output widths.\"\"\"\n",
    "        super().__init__()\n",
    "        self.linear = nn.Linear(in_features=input_dim, out_features=output_dim)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Apply the linear layer to `x` and return the result.\"\"\"\n",
    "        prediction = self.linear(x)\n",
    "        return prediction\n",
    "\n",
    "\n",
    "# Instantiate the model with the dimensions chosen for the data set.\n",
    "model = SimpleModel(input_dim=input_size, output_dim=output_size)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a291b0a8-ba0d-4bb5-84e0-8cb8ae7d3c92",
   "metadata": {},
   "source": [
    "## 4. 设置损失函数和优化器"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f8b6518e-51ed-4133-93e4-b931fb091f7a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mean squared error between the model outputs and the targets.\n",
    "criterion = nn.MSELoss()\n",
    "\n",
    "# Stochastic gradient descent over all parameters of the model.\n",
    "optimizer = optim.SGD(params=model.parameters(), lr=0.01)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bcc63c26-9da7-49cd-9dba-ae42974b4fca",
   "metadata": {},
   "source": [
    "## 5. 训练模型\n",
    "\n",
    "执行训练循环。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "6f47d066-e166-4953-9c08-13f02c4b2d3d",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch [10/100], Loss: 0.7263\n",
      "Epoch [20/100], Loss: 0.4859\n",
      "Epoch [30/100], Loss: 0.3933\n",
      "Epoch [40/100], Loss: 0.3320\n",
      "Epoch [50/100], Loss: 0.2846\n",
      "Epoch [60/100], Loss: 0.2459\n",
      "Epoch [70/100], Loss: 0.2137\n",
      "Epoch [80/100], Loss: 0.1865\n",
      "Epoch [90/100], Loss: 0.1632\n",
      "Epoch [100/100], Loss: 0.1432\n"
     ]
    }
   ],
   "source": [
    "epochs = 100\n",
    "\n",
    "for epoch in range(epochs):\n",
    "    # Clear gradients left over from the previous iteration before computing new ones.\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    # Forward pass: model outputs for the whole batch and the resulting loss.\n",
    "    outputs = model(x)\n",
    "    loss = criterion(outputs, y)\n",
    "\n",
    "    # Backward pass followed by one parameter update.\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "\n",
    "    # Only report progress on every tenth epoch.\n",
    "    if (epoch+1) % 10 != 0:\n",
    "        continue\n",
    "    print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c3702a02-7414-4807-b1ce-d5dd383ef740",
   "metadata": {},
   "source": [
    "## 6. 测试模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "40da7aee-ab73-436c-b547-c25a0b893c70",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Prediction for new data point: -0.15251392126083374\n"
     ]
    }
   ],
   "source": [
    "# This cell only performs inference, so put the model in evaluation mode and turn off\n",
    "# gradient tracking; no autograd graph is needed for the prediction.\n",
    "model.eval()\n",
    "with torch.no_grad():\n",
    "    new_data_point = torch.randn(1, input_size)\n",
    "    prediction = model(new_data_point)\n",
    "print(f'Prediction for new data point: {prediction.item()}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37e9465c-32f1-4594-b60e-c867bd0d715f",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.9.16",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.16"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
